{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[0, 0, 0, 0, 0, 0, 0],\n",
       "        [1, 0, 1, 0, 0, 0, 0],\n",
       "        [1, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 1, 0, 0, 1, 1, 0],\n",
       "        [1, 1, 0, 1, 1, 1, 0],\n",
       "        [0, 2, 2, 0, 2, 1, 1],\n",
       "        [2, 1, 1, 1, 0, 0, 1],\n",
       "        [1, 1, 0, 0, 1, 1, 1],\n",
       "        [2, 0, 0, 2, 2, 0, 1],\n",
       "        [0, 0, 1, 1, 1, 0, 1]]),\n",
       " array([[0, 0, 1, 0, 0, 0, 0],\n",
       "        [2, 0, 0, 0, 0, 0, 0],\n",
       "        [1, 1, 0, 0, 1, 0, 0],\n",
       "        [1, 1, 1, 1, 1, 0, 1],\n",
       "        [2, 2, 2, 2, 2, 0, 1],\n",
       "        [2, 0, 0, 2, 2, 1, 1],\n",
       "        [0, 1, 0, 1, 0, 0, 1]]))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "#加载数据集\n",
    "def load_data():\n",
    "    with open('决策树数据.txt') as fr:\n",
    "        lines = fr.readlines()\n",
    "\n",
    "    x = np.empty((len(lines), 7), dtype=int)\n",
    "\n",
    "    for i in range(len(lines)):\n",
    "        line = lines[i].strip().split(',')\n",
    "        x[i] = line\n",
    "\n",
    "    test_x = x[10:]\n",
    "    x = x[:10]\n",
    "\n",
    "    return x, test_x\n",
    "\n",
    "\n",
    "x, test_x = load_data()\n",
    "x, test_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Node col=0\n",
      "Leaf y=1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(None, None)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#创建节点和叶子对象,用来构建树\n",
    "class Node():\n",
    "    def __init__(self, col):\n",
    "        self.col = col\n",
    "        self.children = {}\n",
    "\n",
    "    def __str__(self):\n",
    "        return 'Node col=%d' % self.col\n",
    "\n",
    "\n",
    "class Leaf():\n",
    "    def __init__(self, y):\n",
    "        self.y = y\n",
    "\n",
    "    def __str__(self):\n",
    "        return 'Leaf y=%d' % self.y\n",
    "\n",
    "\n",
    "print(Node(0)), print(Leaf(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 \n"
     ]
    }
   ],
   "source": [
    "#打印树的方法\n",
    "def print_tree(node, prefix='', subfix=''):\n",
    "    prefix += '-' * 4\n",
    "    print(prefix, node, subfix)\n",
    "    if isinstance(node, Leaf):\n",
    "        return\n",
    "    for i in node.children:\n",
    "        subfix = 'value=' + str(i)\n",
    "        print_tree(node.children[i], prefix, subfix)\n",
    "\n",
    "\n",
    "print_tree(Node(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 \n",
      "-------- Node col=2 value=0\n",
      "------------ Leaf y=0 value=0\n",
      "------------ Leaf y=1 value=1\n",
      "------------ Leaf y=1 value=2\n",
      "-------- Node col=1 value=1\n",
      "------------ Leaf y=0 value=0\n",
      "------------ Node col=3 value=1\n",
      "---------------- Leaf y=1 value=0\n",
      "---------------- Leaf y=0 value=1\n",
      "-------- Leaf y=1 value=2\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "\n",
    "#加载树,剪枝用\n",
    "with open('tree.dump', 'rb') as fr:\n",
    "    root = pickle.load(fr)\n",
    "\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=3 \n",
      "-------- Leaf y=1 value=0\n",
      "-------- Leaf y=0 value=1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(array([[1, 1, 0, 1, 1, 1, 0],\n",
       "        [1, 1, 0, 0, 1, 1, 1]]),\n",
       " array([[1, 1, 0, 0, 1, 0, 0],\n",
       "        [1, 1, 1, 1, 1, 0, 1]]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#先从最后一个节点开始切,也就是0=1,1=1\n",
    "#先把数据和测试数据筛选出来\n",
    "x_0_1 = x[x[:, 0] == 1]\n",
    "x_0_1_and_1_1 = x_0_1[x_0_1[:, 1] == 1]\n",
    "\n",
    "test_x_0_1 = test_x[test_x[:, 0] == 1]\n",
    "test_x_0_1_and_1_1 = test_x_0_1[test_x_0_1[:, 1] == 1]\n",
    "\n",
    "#取节点\n",
    "node = root.children[1].children[1]\n",
    "\n",
    "print_tree(node)\n",
    "x_0_1_and_1_1, test_x_0_1_and_1_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 \n",
      "-------- Node col=2 value=0\n",
      "------------ Leaf y=0 value=0\n",
      "------------ Leaf y=1 value=1\n",
      "------------ Leaf y=1 value=2\n",
      "-------- Node col=1 value=1\n",
      "------------ Leaf y=0 value=0\n",
      "------------ Leaf y=0 value=1\n",
      "-------- Leaf y=1 value=2\n"
     ]
    }
   ],
   "source": [
    "def after_cut(node, _x, test_x):\n",
    "    #先计算不剪切的测试正确率\n",
    "    pre_correct = 0\n",
    "    for split_value in np.unique(test_x[:, node.col]):\n",
    "        sub_test_x = test_x[test_x[:, node.col] == split_value]\n",
    "        sub_test_y = sub_test_x[:, -1]\n",
    "\n",
    "        pre_correct += np.sum(sub_test_y == node.children[split_value].y)\n",
    "\n",
    "    #计算剪切的测试正确率\n",
    "    #取y\n",
    "    _y = _x[:, -1]\n",
    "    test_y = test_x[:, -1]\n",
    "\n",
    "    #求众数\n",
    "    vote_y = np.bincount(_y).argmax()\n",
    "\n",
    "    #计算测试正确率\n",
    "    after_correct = np.sum(test_y == vote_y)\n",
    "\n",
    "    #如果剪切后的测试正确率上升或不变,就剪切,否则不剪切\n",
    "    if after_correct >= pre_correct:\n",
    "        return Leaf(y=vote_y)\n",
    "\n",
    "    return node\n",
    "\n",
    "\n",
    "root.children[1].children[1] = after_cut(node, x_0_1_and_1_1,\n",
    "                                         test_x_0_1_and_1_1)\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=1 \n",
      "-------- Leaf y=0 value=0\n",
      "-------- Leaf y=0 value=1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(array([[1, 0, 1, 0, 0, 0, 0],\n",
       "        [1, 0, 0, 0, 0, 0, 0],\n",
       "        [1, 1, 0, 1, 1, 1, 0],\n",
       "        [1, 1, 0, 0, 1, 1, 1]]),\n",
       " array([[1, 1, 0, 0, 1, 0, 0],\n",
       "        [1, 1, 1, 1, 1, 0, 1]]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#剪切0=1\n",
    "#先把数据和测试数据筛选出来\n",
    "x_0_1 = x[x[:, 0] == 1]\n",
    "\n",
    "test_x_0_1 = test_x[test_x[:, 0] == 1]\n",
    "\n",
    "#取节点\n",
    "node = root.children[1]\n",
    "\n",
    "print_tree(node)\n",
    "x_0_1, test_x_0_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 \n",
      "-------- Node col=2 value=0\n",
      "------------ Leaf y=0 value=0\n",
      "------------ Leaf y=1 value=1\n",
      "------------ Leaf y=1 value=2\n",
      "-------- Leaf y=0 value=1\n",
      "-------- Leaf y=1 value=2\n"
     ]
    }
   ],
   "source": [
    "root.children[1] = after_cut(node, x_0_1, test_x_0_1)\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=2 \n",
      "-------- Leaf y=0 value=0\n",
      "-------- Leaf y=1 value=1\n",
      "-------- Leaf y=1 value=2\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(array([[0, 0, 0, 0, 0, 0, 0],\n",
       "        [0, 1, 0, 0, 1, 1, 0],\n",
       "        [0, 2, 2, 0, 2, 1, 1],\n",
       "        [0, 0, 1, 1, 1, 0, 1]]),\n",
       " array([[0, 0, 1, 0, 0, 0, 0],\n",
       "        [0, 1, 0, 1, 0, 0, 1]]))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#剪切0=0\n",
    "#先把数据和测试数据筛选出来\n",
    "x_0_0 = x[x[:, 0] == 0]\n",
    "\n",
    "test_x_0_0 = test_x[test_x[:, 0] == 0]\n",
    "\n",
    "#取节点\n",
    "node = root.children[0]\n",
    "\n",
    "print_tree(node)\n",
    "x_0_0, test_x_0_0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 \n",
      "-------- Leaf y=0 value=0\n",
      "-------- Leaf y=0 value=1\n",
      "-------- Leaf y=1 value=2\n"
     ]
    }
   ],
   "source": [
    "root.children[0] = after_cut(node, x_0_0, test_x_0_0)\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---- Node col=0 \n",
      "-------- Leaf y=0 value=0\n",
      "-------- Leaf y=0 value=1\n",
      "-------- Leaf y=1 value=2\n"
     ]
    }
   ],
   "source": [
    "#剪切根节点\n",
    "root = after_cut(root, x, test_x)\n",
    "print_tree(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.7\n",
      "-------------------------\n",
      "0.5714285714285714\n"
     ]
    }
   ],
   "source": [
    "#预测方法,测试\n",
    "def pred(_x, node):\n",
    "    col_value = _x[node.col]\n",
    "    node = node.children[col_value]\n",
    "\n",
    "    if isinstance(node, Leaf):\n",
    "        return node.y\n",
    "\n",
    "    return pred(_x, node)\n",
    "\n",
    "\n",
    "correct = 0\n",
    "for i in x:\n",
    "    if pred(i, root) == i[-1]:\n",
    "        correct += 1\n",
    "\n",
    "print(correct / len(x))\n",
    "\n",
    "print('-------------------------')\n",
    "\n",
    "correct = 0\n",
    "for i in test_x:\n",
    "    if pred(i, root) == i[-1]:\n",
    "        correct += 1\n",
    "\n",
    "print(correct / len(test_x))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
